package com.simon.study.algorithm.leetcode;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * LeetCode 2333, "Minimum Sum of Squared Difference".
 *
 * <p>Given two arrays of equal length and a total of {@code k1 + k2} operations, each of which
 * changes one element of either array by {@code +1} or {@code -1}, compute the minimum possible
 * value of {@code sum((nums1[i] - nums2[i])^2)}.
 *
 * @author simon
 */
public class MinSumSquareDiff02333Dup2 {

    /**
     * Small demo: eleven index pairs that each differ by 100000 and no operations, so the
     * printed answer is {@code 11 * 100000^2}.
     */
    public static void main(String[] args) {
        int[] nums1 = new int[]{0,0,100000,0,100000,0,100000,100000,0,100000,0};
        int[] nums2 = new int[]{100000,100000,0,100000,0,100000,0,0,100000,0,100000};
        int k1 = 0, k2 = 0;
        System.out.println(minSumSquareDiff(nums1, nums2, k1, k2));
    }

    /**
     * A (value, multiplicity) pair. Not used by {@link #minSumSquareDiff}; retained because it
     * is visible to the rest of the package.
     */
    static class Pair{
        Integer num;
        int times;

        public Pair(){}

        public Pair(Integer num, int times) {
            this.num = num;
            this.times = times;
        }
    }

    /**
     * Returns the minimum sum of squared differences after spending up to {@code k1 + k2} unit
     * operations.
     *
     * <p>Only {@code |nums1[i] - nums2[i]|} matters, and any single operation (on either array)
     * can shrink one such difference by one, so {@code k1} and {@code k2} form one shared budget.
     * Lowering a currently-largest difference is always optimal. The differences are kept in a
     * sorted histogram, and the top group is lowered in bulk. Either the whole group drops to the
     * next distinct value and merges with it, or, when the budget runs short, the group is
     * lowered as evenly as possible. This runs in O(n log n) time regardless of how large the
     * differences or the budget are.
     *
     * <p>For inputs in the problem's range (values up to 10^5) the result always fits in a
     * {@code long}. For arbitrary {@code int} inputs the squared differences may exceed it.
     *
     * @param nums1 first array
     * @param nums2 second array; indices {@code 0 .. nums1.length - 1} are read
     * @param k1 operations allowed on {@code nums1}
     * @param k2 operations allowed on {@code nums2}
     * @return the minimised sum of squared differences ({@code 0} for empty input); if
     *     {@code k1 + k2} is not positive, no operations are applied
     */
    public static long minSumSquareDiff(int[] nums1, int[] nums2, int k1, int k2) {
        // Histogram: |nums1[i] - nums2[i]| -> number of indices with that difference.
        // The subtraction is done in long because it can overflow int for arbitrary inputs.
        TreeMap<Long, Long> counts = new TreeMap<>();
        for (int i = 0; i < nums1.length; i++) {
            counts.merge(Math.abs((long) nums1[i] - nums2[i]), 1L, Long::sum);
        }

        // Summed in long: k1 + k2 in int overflows when both are close to Integer.MAX_VALUE.
        long budget = (long) k1 + k2;

        while (budget > 0 && !counts.isEmpty() && counts.lastKey() > 0) {
            Map.Entry<Long, Long> top = counts.pollLastEntry();
            long value = top.getKey();
            long count = top.getValue();
            // Level at which the top group would merge with the next group (or reach zero).
            long floor = counts.isEmpty() ? 0L : counts.lastKey();
            long cost = count * (value - floor);

            if (budget >= cost) {
                budget -= cost;
                counts.merge(floor, count, Long::sum);
            } else {
                // Not enough budget to reach floor. Lower every element by budget / count, then
                // lower budget % count of them by one more. Both resulting levels are >= floor.
                long lowered = value - budget / count;
                long extra = budget % count;
                counts.merge(lowered, count - extra, Long::sum);
                if (extra > 0) {
                    counts.merge(lowered - 1, extra, Long::sum);
                }
                budget = 0;
            }
        }

        long sum = 0;
        for (Map.Entry<Long, Long> entry : counts.entrySet()) {
            long diff = entry.getKey();
            sum += diff * diff * entry.getValue();
        }
        return sum;
    }
}
